{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "try:\n",
    "    import cPickle as pickle\n",
    "except:\n",
    "    import pickle\n",
    "import pandas as pd\n",
    "import mxnet as mx\n",
    "import wget\n",
    "import time\n",
    "import os.path\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "import logging\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "import queue as Queue\n",
    "import functools\n",
    "import threading\n",
    "import os.path\n",
    "from mxnet.io import DataBatch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "ALPHABET = list(\"abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\\\"/\\\\|_@#$%^&*~`+ =<>()[]{}\")\n",
    "FEATURE_LEN = 1014\n",
    "BATCH_SIZE = 128\n",
    "NUM_FILTERS = 256\n",
    "DATA_SHAPE = (BATCH_SIZE, 1, FEATURE_LEN, len(ALPHABET))\n",
    "\n",
    "ctx = mx.gpu(2)\n",
    "EPOCHS = 10\n",
    "SD = 0.05  # std for gaussian distribution\n",
    "INITY = mx.init.Normal(sigma=SD)\n",
    "LR = 0.01\n",
    "MOMENTUM = 0.9\n",
    "WDECAY = 0.00001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# logging\n",
    "logger = logging.getLogger()\n",
    "fhandler = logging.FileHandler(filename='crepe_dbp.log', mode='a')\n",
    "formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n",
    "fhandler.setFormatter(formatter)\n",
    "logger.addHandler(fhandler)\n",
    "logger.setLevel(logging.DEBUG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_file(infile):\n",
    "    print(\"processing data frame: %s\" % infile)\n",
    "    # load data into dataframe\n",
    "    df = pd.read_csv(infile,\n",
    "                     header=None,\n",
    "                     names=['sentiment', 'summary', 'text'])\n",
    "    # concat summary, review; trim to 1014 char; reverse; lower\n",
    "    df['rev'] = df.apply(lambda x: \"%s %s\" % (x['summary'], x['text']), axis=1)\n",
    "    df.rev = df.rev.str[:FEATURE_LEN].str[::-1].str.lower()\n",
    "    # store class as nparray\n",
    "    y_split = np.asarray(df.sentiment, dtype='int')\n",
    "    print(\"finished processing data frame: %s\" % infile)\n",
    "    print(\"data contains %d obs, each epoch will contain %d batches\" % (df.shape[0], df.shape[0]//BATCH_SIZE))\n",
    "    return df.rev, y_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_data_frame(X_data, y_data, batch_size=128, shuffle=False):\n",
    "    \"\"\"\n",
    "    For low RAM this methods allows us to keep only the original data\n",
    "    in RAM and calculate the features (which are orders of magnitude bigger\n",
    "    on the fly). This keeps only 10 batches worth of features in RAM using\n",
    "    asynchronous programing and yields one DataBatch() at a time.\n",
    "    \"\"\"\n",
    "\n",
    "    if shuffle:\n",
    "        idx = X_data.index\n",
    "        assert len(idx) == len(y_data)\n",
    "        rnd = np.random.permutation(idx)\n",
    "        X_data = X_data.reindex(rnd)\n",
    "        y_data = y_data[rnd]\n",
    "\n",
    "    # Dictionary to create character vectors\n",
    "    hashes = {} \n",
    "    for index, letter in enumerate(ALPHABET):\n",
    "        hashes[letter] = np.zeros(len(ALPHABET), dtype='bool')\n",
    "        hashes[letter][index] = True\n",
    "\n",
    "    # Yield processed batches asynchronously\n",
    "    # Buffy 'batches' at a time\n",
    "    def async_prefetch_wrp(iterable, buffy=1):#buffy=30\n",
    "        poison_pill = object()\n",
    "\n",
    "        def worker(q, it):\n",
    "            for item in it:\n",
    "                q.put(item)\n",
    "            q.put(poison_pill)\n",
    "\n",
    "        queue = Queue.Queue(buffy)\n",
    "        it = iter(iterable)\n",
    "        thread = threading.Thread(target=worker, args=(queue, it))\n",
    "        thread.daemon = True\n",
    "        thread.start()\n",
    "        while True:\n",
    "            item = queue.get()\n",
    "            if item == poison_pill:\n",
    "                return\n",
    "            else:\n",
    "                yield item\n",
    "\n",
    "    # Async wrapper around\n",
    "    def async_prefetch(func):\n",
    "        @functools.wraps(func)\n",
    "        def wrapper(*args, **kwds):\n",
    "            return async_prefetch_wrp(func(*args, **kwds))\n",
    "        return wrapper\n",
    "\n",
    "    @async_prefetch\n",
    "    def feature_extractor(dta, val):\n",
    "        # Yield mini-batch amount of character vectors\n",
    "        X_split = np.zeros([batch_size, 1, FEATURE_LEN, len(ALPHABET)], dtype='bool')\n",
    "        for ti, tx in enumerate(dta):\n",
    "            chars = list(tx)\n",
    "            print(tx)\n",
    "            for ci, ch in enumerate(chars):\n",
    "                if ch in hashes:\n",
    "                    X_split[ti % batch_size][0][ci] = hashes[ch].copy()\n",
    "            # No padding -> only complete batches processed\n",
    "            if (ti + 1) % batch_size == 0:\n",
    "                #yield mx.nd.array(X_split), mx.nd.array(val[ti + 1 - batch_size:ti + 1])\n",
    "                yield X_split, val[ti + 1 - batch_size:ti + 1]\n",
    "                X_split = np.zeros([batch_size, 1, FEATURE_LEN, len(ALPHABET)], dtype='bool')\n",
    "\n",
    "    # Yield one mini-batch at a time and asynchronously process to keep 4 in queue\n",
    "    for Xsplit, ysplit in feature_extractor(X_data, y_data):\n",
    "        #yield DataBatch(data=[Xsplit], label=[ysplit])\n",
    "        yield Xsplit, ysplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def create_crepe():\n",
    "    \"\"\"\n",
    "    Replicating: https://github.com/zhangxiangxiao/Crepe/blob/master/train/config.lua\n",
    "    \"\"\"\n",
    "    input_x = mx.sym.Variable('data')  # placeholder for input\n",
    "    input_y = mx.sym.Variable('softmax_label')  # placeholder for output\n",
    "    # 1. alphabet x 1014\n",
    "    conv1 = mx.symbol.Convolution(data=input_x, kernel=(7, 69), num_filter=NUM_FILTERS)\n",
    "    relu1 = mx.symbol.Activation(data=conv1, act_type=\"relu\")\n",
    "    pool1 = mx.symbol.Pooling(data=relu1, pool_type=\"max\", kernel=(3, 1), stride=(3, 1))\n",
    "    # 2. 336 x 256\n",
    "    conv2 = mx.symbol.Convolution(data=pool1, kernel=(7, 1), num_filter=NUM_FILTERS)\n",
    "    relu2 = mx.symbol.Activation(data=conv2, act_type=\"relu\")\n",
    "    pool2 = mx.symbol.Pooling(data=relu2, pool_type=\"max\", kernel=(3, 1), stride=(3, 1))\n",
    "    # 3. 110 x 256\n",
    "    conv3 = mx.symbol.Convolution(data=pool2, kernel=(3, 1), num_filter=NUM_FILTERS)\n",
    "    relu3 = mx.symbol.Activation(data=conv3, act_type=\"relu\")\n",
    "    # 4. 108 x 256\n",
    "    conv4 = mx.symbol.Convolution(data=relu3, kernel=(3, 1), num_filter=NUM_FILTERS)\n",
    "    relu4 = mx.symbol.Activation(data=conv4, act_type=\"relu\")\n",
    "    # 5. 106 x 256\n",
    "    conv5 = mx.symbol.Convolution(data=relu4, kernel=(3, 1), num_filter=NUM_FILTERS)\n",
    "    relu5 = mx.symbol.Activation(data=conv5, act_type=\"relu\")\n",
    "    # 6. 104 x 256\n",
    "    conv6 = mx.symbol.Convolution(data=relu5, kernel=(3, 1), num_filter=NUM_FILTERS)\n",
    "    relu6 = mx.symbol.Activation(data=conv6, act_type=\"relu\")\n",
    "    pool6 = mx.symbol.Pooling(data=relu6, pool_type=\"max\", kernel=(3, 1), stride=(3, 1))\n",
    "    # 34 x 256\n",
    "    flatten = mx.symbol.Flatten(data=pool6)\n",
    "    # 7.  8704\n",
    "    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=1024)\n",
    "    act_fc1 = mx.symbol.Activation(data=fc1, act_type=\"relu\")\n",
    "    drop1 = mx.sym.Dropout(act_fc1, p=0.5)\n",
    "    # 8. 1024\n",
    "    fc2 = mx.symbol.FullyConnected(data=drop1, num_hidden=1024)\n",
    "    act_fc2 = mx.symbol.Activation(data=fc2, act_type=\"relu\")\n",
    "    drop2 = mx.sym.Dropout(act_fc2, p=0.5)\n",
    "    # 9. 1024\n",
    "    fc3 = mx.symbol.FullyConnected(data=drop2, num_hidden=NOUTPUT)\n",
    "    crepe = mx.symbol.SoftmaxOutput(data=fc3, label=input_y, name=\"softmax\")\n",
    "    return crepe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Visualise symbol (for crepe)\n",
    "crepe = create_crepe()\n",
    "\n",
    "#a = mx.viz.plot_network(crepe)\n",
    "#a.render('Crepe Model')\n",
    "#a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def save_check_point(mod_arg, mod_aux, pre, epoch):\n",
    "    \"\"\"\n",
    "    Save model each epoch, load as:\n",
    "\n",
    "    sym, arg_params, aux_params = \\\n",
    "        mx.model.load_checkpoint(model_prefix, n_epoch_load)\n",
    "\n",
    "    # assign parameters\n",
    "    mod.set_params(arg_params, aux_params)\n",
    "\n",
    "    OR\n",
    "\n",
    "    mod.fit(..., arg_params=arg_params, aux_params=aux_params,\n",
    "            begin_epoch=n_epoch_load)\n",
    "    \"\"\"\n",
    "\n",
    "    save_dict = {('arg:%s' % k): v for k, v in mod_arg.items()}\n",
    "    save_dict.update({('aux:%s' % k): v for k, v in mod_aux.items()})\n",
    "    param_name = '%s-%04d.pk' % (pre, epoch)\n",
    "    pickle.dump(save_dict, open(param_name, \"wb\"))\n",
    "    print('Saved checkpoint to \\\"%s\\\"' % param_name)\n",
    "    print('Saving model with mxnet notation')\n",
    "    mx.callback.do_checkpoint(pre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train_model(prefix, filename):\n",
    "    print(\"Initializing model\")\n",
    "    # Create mx.mod.Module()\n",
    "    cnn = create_crepe()\n",
    "    mod = mx.mod.Module(cnn, context=ctx)\n",
    "    \n",
    "    # Bind shape\n",
    "    mod.bind(data_shapes=[('data', DATA_SHAPE)],\n",
    "             label_shapes=[('softmax_label', (BATCH_SIZE,))])\n",
    "\n",
    "    # Initialise parameters and optimiser\n",
    "    mod.init_params(mx.init.Normal(sigma=SD))\n",
    "    mod.init_optimizer(optimizer='sgd',\n",
    "                       optimizer_params={\n",
    "                           \"learning_rate\": LR,\n",
    "                           \"momentum\": MOMENTUM,\n",
    "                           \"wd\": WDECAY,\n",
    "                           \"rescale_grad\": 1.0/BATCH_SIZE\n",
    "                       })\n",
    "\n",
    "    print(\"Loading file\")\n",
    "    # Load Data\n",
    "    X_train, y_train = load_file(filename)\n",
    "\n",
    "    # Train\n",
    "    print(\"Alphabet %d characters: \" % len(ALPHABET), ALPHABET)\n",
    "    print(\"started training\")\n",
    "    tic = time.time()\n",
    "\n",
    "    # Evaluation metric:\n",
    "    metric = mx.metric.Accuracy()\n",
    "\n",
    "    # Train EPOCHS\n",
    "    for epoch in range(EPOCHS):\n",
    "        t = 0\n",
    "        metric.reset()\n",
    "        tic_in = time.time()\n",
    "        for batch in load_data_frame(X_data=X_train,\n",
    "                                     y_data=y_train,\n",
    "                                     batch_size=BATCH_SIZE,\n",
    "                                     shuffle=True):\n",
    "            # Push data forwards and update metric\n",
    "            mod.forward_backward(batch)\n",
    "            mod.update()\n",
    "            mod.update_metric(metric, batch.label)\n",
    "\n",
    "            # For training + testing\n",
    "            #mod.forward(batch, is_train=True)\n",
    "            #mod.update_metric(metric, batch.label)\n",
    "            # Get weights and update\n",
    "            # For training only\n",
    "            #mod.backward()\n",
    "            #mod.update()\n",
    "            # Log every 50 batches = 128*50 = 6400\n",
    "            t += 1\n",
    "            if t % 50 == 0:\n",
    "                train_t = time.time() - tic_in\n",
    "                metric_m, metric_v = metric.get()\n",
    "                print(\"epoch: %d iter: %d metric(%s): %.4f dur: %.0f\" % (epoch, t, metric_m, metric_v, train_t))\n",
    "\n",
    "        # Checkpoint\n",
    "        print(\"Saving checkpoint\")\n",
    "        arg_params, aux_params = mod.get_params()\n",
    "        save_check_point(mod_arg=arg_params,\n",
    "                         mod_aux=aux_params,\n",
    "                         pre=prefix,\n",
    "                         epoch=epoch)\n",
    "        print(\"Finished epoch %d\" % epoch)\n",
    "\n",
    "    print(\"Done. Finished in %.0f seconds\" % (time.time() - tic))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "NOUTPUT = 14  # Classes\n",
    "model_prefix = 'crepe_dbpedia_prefetch'\n",
    "train_file = '/datadrive/nlp/dbpedia_train.csv'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "processing data frame: /datadrive/nlp/categories_train_big.csv\n",
      "finished processing data frame: /datadrive/nlp/categories_train_big.csv\n",
      "data contains 2379999 obs, each epoch will contain 18593 batches\n"
     ]
    }
   ],
   "source": [
    "#test data\n",
    "X_train, y_train = load_file(train_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".maj siht fo tuo em gnitteg ni lufpleh oot t'nsaw koob eht tub ,3 lla ro ,em ,koob eht ,7 swodniw htiw seil tluaf eht fi wonk t'nod i .sdrawretfa yllaudividni eno hcae lebal ot evah i dna meht sdaolnwod yllacitamotua 7 swodniw ym wohemos .spuorg ni meht gnilebal dna serutcip gnidaolnwod fo erudecorp elpmis eht gnidrager depmuts neeb ev'i ?hcum oot tcepxe i did ?hcum oot tcepxe i did\n",
      "deppots nehw od ot tahw wonk lliw yeht yllufepoh dna sdik ruoy ot siht wohs .deppots nehw evaheb otwoh wonk dluohs sevird ohw enoyreve.esnes nommoc s'ti yllaer tub .doog srevird gnuoy dna srevird wen rof doog\n",
      "label =  [0]\n",
      "data shape =  (1, 1, 1014, 69)\n",
      "data = [[[[0 0 0 ..., 0 0 0]\n",
      "   [0 0 0 ..., 0 0 0]\n",
      "   [1 0 0 ..., 0 0 0]\n",
      "   ..., \n",
      "   [0 0 0 ..., 0 0 0]\n",
      "   [0 0 0 ..., 0 0 0]\n",
      "   [0 0 0 ..., 0 0 0]]]]\n",
      "dint shape =  (1014, 69)\n",
      "<class 'numpy.ndarray'>\n",
      ".detide llew dna daer htrae ot nwod yrev !reverof gnitirw speek ehs epoh yllaer i !!skoob reh lla evol  !retirw gnidnatstuo !!!!koob emosewa rehtona\n"
     ]
    }
   ],
   "source": [
    "\n",
    "for batch in load_data_frame(X_data=X_train, y_data=y_train, batch_size=1,shuffle=False):\n",
    "    #print(\"label = \", np.asarray(batch.label))\n",
    "    #print(\"data = \", np.asarray(batch.data))\n",
    "    d,l = batch\n",
    "    dint = np.asarray(d,dtype='int32')\n",
    "    break\n",
    "print(\"label = \", l)\n",
    "print(\"data shape = \", d.shape)\n",
    "print(\"data =\", dint)\n",
    "dint = np.reshape(dint, [1014, 69])\n",
    "print(\"dint shape = \", dint.shape)\n",
    "print(type(dint))\n",
    "#df = pd.DataFrame(dint)\n",
    "#df.to_csv(\"file_path.csv\")\n",
    "np.savetxt('first_batch.csv', dint,  delimiter=',', fmt='%d',)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_check_point(file_name):\n",
    "\n",
    "    # Load file\n",
    "    print(file_name)\n",
    "    save_dict = pickle.load(open(file_name, \"rb\"))\n",
    "    # Extract data from save\n",
    "    arg_params = {}\n",
    "    aux_params = {}\n",
    "    for k, v in save_dict.items():\n",
    "        tp, name = k.split(':', 1)\n",
    "        if tp == 'arg':\n",
    "            arg_params[name] = v\n",
    "        if tp == 'aux':\n",
    "            aux_params[name] = v\n",
    "\n",
    "    # Recreate model\n",
    "    cnn = create_crepe()\n",
    "    mod = mx.mod.Module(cnn, context=ctx)\n",
    "\n",
    "    # Bind shape\n",
    "    mod.bind(data_shapes=[('data', DATA_SHAPE)],\n",
    "             label_shapes=[('softmax_label', (BATCH_SIZE,))])\n",
    "\n",
    "    # assign parameters from save\n",
    "    mod.set_params(arg_params, aux_params)\n",
    "    print('Model loaded from disk')\n",
    "\n",
    "    return mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def test_model(pickled_model, filename):\n",
    "    \"\"\" This doesn't take too long but still seems it takes longer than\n",
    "    it should be taking ... \"\"\"\n",
    "\n",
    "    # Load saved model:\n",
    "    mod = load_check_point(pickled_model)\n",
    "    #assert mod.binded and mod.params_initialized\n",
    "\n",
    "    # Load data\n",
    "    X_test, y_test = load_file(filename)\n",
    "\n",
    "    # Score accuracy\n",
    "    metric = mx.metric.Accuracy()\n",
    "\n",
    "    # Test batches\n",
    "    for batch in load_data_frame(X_data=X_test,\n",
    "                                 y_data=y_test,\n",
    "                                 batch_size=BATCH_SIZE):\n",
    "\n",
    "        mod.forward(batch, is_train=False)\n",
    "        mod.update_metric(metric, batch.label)\n",
    "        metric_m, metric_v = metric.get()\n",
    "        print(\"TEST(%s): %.4f\" % (metric_m, metric_v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model_prefix = 'crepe_dbpedia_prefetch'\n",
    "model_epoch = 9\n",
    "model_pk = model_prefix + '_000' + str(model_epoch) + '.pk'\n",
    "test_file = '/datadrive/nlp/dbpedia_test.csv'\n",
    "test_model(model_pk, test_file)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
